import re
import torch
import random

from typing import Any, Tuple
from torchvision import datasets


def image_coordinate(path: str):
    """Return the last run of decimal digits in ``path`` read as ``0.<digits>``.

    The fraction is rounded to 3 decimal places, e.g. ``'img_5_12345.png'``
    gives ``0.123``. Raises IndexError if ``path`` contains no digits.

    NOTE(review): digit runs anywhere in ``path`` (including directory names)
    are scanned; only the final one is used here.
    """
    digit_runs = re.findall(r'\d+', path)
    fraction = float(f'0.{digit_runs[-1]}')
    return round(fraction, 3)

def all_coordinates(path: str):
    """Return four coordinates parsed from the digit runs of ``path``.

    The digit runs at positions -7, -5, -3 and -1 (counted from the end) are
    each read as the fraction ``0.<digits>`` and rounded to 3 decimal places.
    Raises IndexError if ``path`` has fewer than seven digit runs.

    NOTE(review): digits in directory names are counted too — confirm the
    root path cannot contribute extra runs near the end.
    """
    # Scan the path once rather than re-running the regex for every coordinate.
    digit_runs = re.findall(r'\d+', path)
    return [round(float('0.' + digit_runs[i]), 3) for i in (-7, -5, -3, -1)]

def image_class(path: str):
    """Return the second run of decimal digits in ``path`` as an ``int``.

    Presumably this is the class label encoded in the path — confirm against
    the dataset layout. Raises IndexError if fewer than two digit runs exist.
    """
    matches = re.findall(r'\d+', path)
    second = matches[1]
    return int(second)

def image_coordinate_joint(path: str):
    """Return the second-to-last run of decimal digits in ``path`` as an ``int``.

    Raises IndexError if ``path`` contains fewer than two digit runs.
    """
    runs = re.findall(r'\d+', path)
    penultimate = runs[-2]
    return int(penultimate)

def image_target_joint(path: str):
    """Return the first two digits of the first digit run in ``path`` as ints.

    For example ``'root/42_x/img.png'`` gives ``[4, 2]``. Raises IndexError if
    ``path`` has no digits or its first digit run is a single digit.
    """
    leading_run = re.findall(r'\d+', path)[0]
    first_digit, second_digit = leading_run[0], leading_run[1]
    return [int(first_digit), int(second_digit)]

class RegressionImageFolder(datasets.ImageFolder):
    """``ImageFolder`` whose samples also carry a regression target parsed from the path.

    After construction every sample is a ``(path, class_index, coordinate)``
    triple. ``class_index`` comes from ``image_class(path)`` and ``coordinate``
    from ``image_coordinate(path)``; they are not the folder-based labels that
    ``ImageFolder`` assigns. Samples are shuffled with a fixed seed (7), so the
    order is reproducible. When ``batch_size`` is given, samples are pre-grouped
    into consecutive batches and ``__getitem__`` returns one whole batch per index.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        orig: If True keep only paths containing ``'original'``; otherwise keep
            only paths that do not contain it.
        batch_size: If not None, group samples into batches of this many items.
            The last batch holds the remainder and may be smaller.
        dist: If True, ``__getitem__`` also returns the regression target(s).
        **kwargs: Forwarded to ``ImageFolder`` (e.g. ``transform``, ``loader``).

    Raises:
        ValueError: If ``batch_size`` is not None and not positive.
    """

    def __init__(
        self, root: str, orig=False, batch_size=None, dist=True, **kwargs: Any
    ) -> None:
        super().__init__(root, **kwargs)
        self.batch_size = batch_size
        self.dist = dist

        # Keep either only the "original" images (orig=True) or only the others.
        paths = [path for path, _ in self.imgs if ('original' in path) == bool(orig)]
        regresstargets = [image_coordinate(path) for path in paths]
        targets = [image_class(path) for path in paths]
        # self.imgs keeps the unshuffled, unbatched triples.
        self.imgs = list(zip(paths, targets, regresstargets))

        self._shuffle(self.imgs)
        if batch_size is not None:
            self._batch(batch_size)

    def _shuffle(self, samples) -> None:
        """Store a fixed-seed shuffle of ``samples`` in samples/targets/regresstargets."""
        # Shuffling an index list consumes the RNG exactly as shuffling the
        # samples themselves would, so the permutation is the same. Unlike
        # ``zip(*shuffled)`` unpacking, it also works when no samples remain.
        order = list(range(len(samples)))
        random.Random(7).shuffle(order)
        self.samples = tuple(samples[i] for i in order)
        self.targets = tuple(sample[1] for sample in self.samples)
        self.regresstargets = tuple(sample[2] for sample in self.samples)

    def _batch(self, batch_size: int) -> None:
        """Regroup samples, targets and regresstargets into consecutive batches.

        ``samples`` becomes a list of lists of triples. ``targets`` and
        ``regresstargets`` become one tensor per batch. The final, possibly
        partial, batch is kept rather than dropped.
        """
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        starts = range(0, len(self.samples), batch_size)
        # Each comprehension reads the un-batched attribute before it is rebound.
        self.samples = [list(self.samples[s:s + batch_size]) for s in starts]
        self.targets = [torch.tensor(self.targets[s:s + batch_size]) for s in starts]
        self.regresstargets = [
            torch.tensor(self.regresstargets[s:s + batch_size]) for s in starts
        ]

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Sample index, or batch index when ``batch_size`` is set.

        Returns:
            tuple: Without batching, ``(sample, target)`` or, if ``dist`` is True,
            ``(sample, target, distance)``; ``target_transform`` is applied to
            ``target``. With batching, ``(samples, targets)`` or
            ``(samples, targets, distances)``, where ``samples`` is the stacked
            batch and the others are tensors; ``target_transform`` is not applied
            in this mode.
        """
        if self.batch_size is None:
            # Samples are (path, target, distance) triples, not ImageFolder pairs.
            path, target, distance = self.samples[index]
            sample = self.loader(path)
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return (sample, target, distance) if self.dist else (sample, target)

        paths, targets, distances = zip(*self.samples[index])
        targets = torch.tensor(targets)
        distances = torch.tensor(distances)
        images = []
        for path in paths:
            image = self.loader(path)
            if self.transform is not None:
                image = self.transform(image)
            images.append(image)
        # torch.stack only accepts tensors, so batched mode needs a transform
        # that returns tensors.
        samples = torch.stack(images)
        if self.dist:
            return samples, targets, distances
        return samples, targets


class JointImageFolder(datasets.ImageFolder):
    """``ImageFolder`` with joint class / coordinate targets parsed from the path.

    After construction every sample is a ``(path, targets, coordinates)``
    triple. ``targets`` is the two-int list from ``image_target_joint(path)``
    and ``coordinates`` is the four-float list from ``all_coordinates(path)``.
    Samples are shuffled with a fixed seed (7), so the order is reproducible.
    When ``batch_size`` is given, samples are pre-grouped into consecutive
    batches and ``__getitem__`` returns one whole batch per index.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        batch_size: If not None, group samples into batches of this many items.
            The last batch holds the remainder and may be smaller.
        dist: If True, ``__getitem__`` also returns the coordinate targets.
        **kwargs: Forwarded to ``ImageFolder`` (e.g. ``transform``, ``loader``).

    Raises:
        ValueError: If ``batch_size`` is not None and not positive.
    """

    def __init__(
            self, root: str, batch_size=None, dist=True, **kwargs: Any
    ) -> None:
        super().__init__(root, **kwargs)
        self.batch_size = batch_size
        self.dist = dist

        paths = [path for path, _ in self.imgs]
        targets = [image_target_joint(path) for path in paths]
        regresstargets = [all_coordinates(path) for path in paths]
        # self.imgs keeps the unshuffled, unbatched triples.
        self.imgs = list(zip(paths, targets, regresstargets))

        self._shuffle(self.imgs)
        if batch_size is not None:
            self._batch(batch_size)

    def _shuffle(self, samples) -> None:
        """Store a fixed-seed shuffle of ``samples`` in samples/targets/regresstargets."""
        # Shuffling an index list consumes the RNG exactly as shuffling the
        # samples themselves would, so the permutation is the same. Unlike
        # ``zip(*shuffled)`` unpacking, it also works when there are no samples.
        order = list(range(len(samples)))
        random.Random(7).shuffle(order)
        self.samples = tuple(samples[i] for i in order)
        self.targets = tuple(sample[1] for sample in self.samples)
        self.regresstargets = tuple(sample[2] for sample in self.samples)

    def _batch(self, batch_size: int) -> None:
        """Regroup samples, targets and regresstargets into consecutive batches.

        ``samples`` becomes a list of lists of triples. ``targets`` and
        ``regresstargets`` become one tensor per batch, with shapes
        ``(n, 2)`` and ``(n, 4)``. The final, possibly partial, batch is kept
        rather than dropped.
        """
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        starts = range(0, len(self.samples), batch_size)
        # Each comprehension reads the un-batched attribute before it is rebound.
        self.samples = [list(self.samples[s:s + batch_size]) for s in starts]
        self.targets = [torch.tensor(self.targets[s:s + batch_size]) for s in starts]
        self.regresstargets = [
            torch.tensor(self.regresstargets[s:s + batch_size]) for s in starts
        ]

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Sample index, or batch index when ``batch_size`` is set.

        Returns:
            tuple: Without batching, ``(sample, target)`` or, if ``dist`` is True,
            ``(sample, target, distance)``; ``target_transform`` is applied to
            ``target``. With batching, ``(samples, targets)`` or
            ``(samples, targets, distances)``, where ``samples`` is the stacked
            batch and the others are tensors; ``target_transform`` is not applied
            in this mode.
        """
        if self.batch_size is None:
            # Samples are (path, target, distance) triples, not ImageFolder pairs.
            path, target, distance = self.samples[index]
            sample = self.loader(path)
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return (sample, target, distance) if self.dist else (sample, target)

        paths, targets, distances = zip(*self.samples[index])
        targets = torch.tensor(targets)
        distances = torch.tensor(distances)
        images = []
        for path in paths:
            image = self.loader(path)
            if self.transform is not None:
                image = self.transform(image)
            images.append(image)
        # torch.stack only accepts tensors, so batched mode needs a transform
        # that returns tensors.
        samples = torch.stack(images)
        if self.dist:
            return samples, targets, distances
        return samples, targets
